from __future__ import annotations

import json
import re
from pathlib import Path

import attr

from mentat.errors import SampleError
from mentat.sampler import __version__


@attr.define
class Sample:
    """A JSON-serializable record describing a single sample.

    A sample can be written to disk with :meth:`save`, read back (with
    migration of older on-disk formats) with :meth:`load`, or constructed from
    a SWE-Bench benchmark row with :meth:`from_swe_bench`.
    """

    # TODO: enforce required fields
    title: str = attr.field(default="")
    description: str = attr.field(default="")
    id: str = attr.field(default="")
    parent_id: str = attr.field(default="")
    # Repository location; from_swe_bench stores a https://github.com/... URL here.
    repo: str = attr.field(default="")
    environment_setup_commit: str | None = attr.field(default=None)
    # Commit the sample is based on (SWE-Bench's base_commit).
    merge_base: str | None = attr.field(default=None)
    diff_merge_base: str = attr.field(default="")
    diff_active: str = attr.field(default="")
    # factory=list gives every instance its own list; a literal [] default
    # would be shared (and mutated) across all instances.
    message_history: list[dict[str, str]] = attr.field(factory=list)
    message_prompt: str = attr.field(default="")
    hint_text: str = attr.field(default="")
    message_edit: str = attr.field(default="")
    # File paths relevant to the sample (from_swe_bench uses the files touched by the patch).
    context: list[str] = attr.field(factory=list)
    diff_edit: str = attr.field(default="")
    test_patch: str = attr.field(default="")
    # JSON-encoded lists of test identifiers, stored as strings.
    FAIL_TO_PASS: str = attr.field(default="")
    PASS_TO_PASS: str = attr.field(default="")
    version: str = attr.field(default=__version__)

    @staticmethod
    def _version_key(version: str) -> tuple[int, ...] | None:
        """Return the leading dotted-numeric part of *version* as an int tuple.

        Returns None when *version* does not start with a number, so callers
        can distinguish "unparseable" from a real version.
        """
        match = re.match(r"\d+(?:\.\d+)*", version)
        if match is None:
            return None
        return tuple(int(part) for part in match.group().split("."))

    def save(self, fname: str | Path) -> None:
        """Write this sample to *fname* as indented JSON."""
        with open(Path(fname), "w", encoding="utf-8") as f:
            json.dump(attr.asdict(self), f, indent=4)

    @classmethod
    def load(cls, fname: str | Path) -> Sample:
        """Read a sample from the JSON file *fname*, migrating old formats.

        Samples older than 0.2.0 have their message history reversed; samples
        older than 0.3.0 gain the SWE-Bench fields (and have ``test_command``
        converted into ``FAIL_TO_PASS``).

        Raises:
            SampleError: if the file has no string ``version`` field, or if the
                version (after migration) does not match the current version.
        """
        with open(fname, "r", encoding="utf-8") as f:
            kwargs = json.load(f)

        _version = kwargs.get("version")
        if not isinstance(_version, str):
            raise SampleError(f"Sample file {fname} has no valid version field (got {_version!r}).")

        # Compare numerically: plain string comparison would order "0.10.0"
        # before "0.2.0".
        parsed = cls._version_key(_version)

        if parsed is not None and parsed < (0, 2, 0):
            kwargs["message_history"] = kwargs.get("message_history", [])[::-1]
            kwargs["version"] = "0.2.0"
            _version = kwargs["version"]
            parsed = (0, 2, 0)

        # An empty version string has always been migrated as a pre-0.3.0
        # sample (without the 0.2.0 history reversal); keep that behaviour.
        if _version == "" or (parsed is not None and parsed < (0, 3, 0)):
            # Additional fields from SWE-Bench
            kwargs["environment_setup_commit"] = ""
            kwargs["hint_text"] = ""
            kwargs["test_patch"] = ""
            if "test_command" in kwargs:
                kwargs["FAIL_TO_PASS"] = json.dumps([kwargs.pop("test_command")])
            else:
                kwargs["FAIL_TO_PASS"] = ""
            kwargs["PASS_TO_PASS"] = ""
            kwargs["version"] = "0.3.0"
            _version = kwargs["version"]

        if _version != __version__:
            raise SampleError(
                f"Warning: sample version ({_version}) does not match current"
                f" version ({__version__})."
            )
        return cls(**kwargs)

    @classmethod
    def from_swe_bench(cls, benchmark: dict[str, str]) -> Sample:
        """Create a Sample from a SWE-Bench benchmark.

        SWE-Bench Fields (https://huggingface.co/datasets/princeton-nlp/SWE-bench#dataset-structure)
        - instance_id: (str) - A formatted instance identifier, usually as repo_owner__repo_name-PR-number.
        - patch: (str) - The gold patch, the patch generated by the PR (minus test-related code), that resolved
          the issue.
        - repo: (str) - The repository owner/name identifier from GitHub.
        - base_commit: (str) - The commit hash of the repository representing the HEAD of the repository before the
          solution PR is applied.
        - hints_text: (str) - Comments made on the issue prior to the creation of the solution PR’s first commit
          creation date.
        - created_at: (str) - The creation date of the pull request.
        - test_patch: (str) - A test-file patch that was contributed by the solution PR.
        - problem_statement: (str) - The issue title and body.
        - version: (str) - Installation version to use for running evaluation.
        - environment_setup_commit: (str) - commit hash to use for environment setup and installation.
        - FAIL_TO_PASS: (str) - A json list of strings that represent the set of tests resolved by the PR and tied to
          the issue resolution.
        - PASS_TO_PASS: (str) - A json list of strings that represent tests that should pass before and after the PR
          application.

        Only ``instance_id`` is required; every other field falls back to an
        empty value when absent.
        """
        patch = benchmark.get("patch", "")
        # Files whose path is unchanged by the patch (a/<path> b/<same path>).
        edited_files = re.findall(r"diff --git a/(.*?) b/\1", patch)
        # The dataset field is "hints_text"; accept the singular spelling too
        # for callers that already pass it that way.
        hints = benchmark.get("hints_text", benchmark.get("hint_text", ""))
        return cls(
            title=f"SWE-bench-{benchmark['instance_id']}",
            description="",
            id=benchmark["instance_id"],
            parent_id="",
            repo=f"https://github.com/{benchmark.get('repo')}",
            environment_setup_commit=benchmark.get("environment_setup_commit", ""),
            merge_base=benchmark.get("base_commit"),
            diff_merge_base="",
            diff_active="",
            message_history=[],
            message_prompt=benchmark.get("problem_statement", ""),
            hint_text=hints,
            message_edit="",
            context=edited_files,
            diff_edit=patch,
            test_patch=benchmark.get("test_patch", ""),
            FAIL_TO_PASS=benchmark.get("FAIL_TO_PASS", ""),
            PASS_TO_PASS=benchmark.get("PASS_TO_PASS", ""),
        )
